{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true,
    "jupyter": {
     "outputs_hidden": true
    }
   },
   "source": [
    "# Concept 1: Object-oriented Transformation\n",
    "\n",
    "[![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/brainpy/examples/blob/main/docs/core_concept/brainpy_transform_concept.ipynb)\n",
    "[![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/brainpy/brainpy/blob/master/docs_version2/core_concept/brainpy_transform_concept.ipynb)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "@[Chaoming Wang](https://github.com/chaoming0625)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Most computation in BrainPy relies on [JAX](https://jax.readthedocs.io/en/latest/). JAX provides composable transformations for Python programs, including differentiation, vectorization, parallelization, and just-in-time compilation. If you are not familiar with it, please see its [documentation](https://jax.readthedocs.io/en/latest/).\n",
    "\n",
    "However, JAX transformations follow a functional programming style: they operate on Python functions. This does not fit brain dynamics modeling, which needs object-oriented programming.\n",
    "\n",
    "To meet this requirement, BrainPy defines an interface for object-oriented (OO) transformations. These OO transformations can be applied to any Python object.\n",
    "\n",
    "This section introduces the BrainPy concept of object-oriented transformations."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-06T02:57:13.270243Z",
     "start_time": "2025-10-06T02:57:09.268648Z"
    }
   },
   "source": [
    "import brainpy as bp\n",
    "import brainpy.math as bm\n",
    "\n",
    "bm.set_platform('cpu')"
   ],
   "outputs": [],
   "execution_count": 1
  },
  {
   "cell_type": "code",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "ExecuteTime": {
     "end_time": "2025-10-06T02:57:13.283751Z",
     "start_time": "2025-10-06T02:57:13.276753Z"
    }
   },
   "source": [
    "bp.__version__"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'3.0.0'"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 2
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "source": [
    "## A simple example"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "source": [
    "Before diving into a real example, let's illustrate the OO transformation concept using a simple case."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "ExecuteTime": {
     "end_time": "2025-10-06T02:57:13.305977Z",
     "start_time": "2025-10-06T02:57:13.301364Z"
    }
   },
   "source": [
    "class Example:\n",
    "    def __init__(self):\n",
    "        self.static = 0\n",
    "        self.dyn = bm.Variable(bm.ones(1))\n",
    "\n",
    "    @bm.cls_jit  # JIT compiled function\n",
    "    def update(self, inp):\n",
    "        self.dyn.value = self.dyn * inp + self.static"
   ],
   "outputs": [],
   "execution_count": 3
  },
  {
   "cell_type": "code",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "ExecuteTime": {
     "end_time": "2025-10-06T02:57:13.352203Z",
     "start_time": "2025-10-06T02:57:13.327533Z"
    }
   },
   "source": [
    "example = Example()"
   ],
   "outputs": [],
   "execution_count": 4
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "source": [
    "To use OO transformations provided in BrainPy, we should keep three things in mind."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "source": [
    "\n",
    "1. All **dynamically changed variables** should be declared as\n",
    "\n",
    "   - an instance of ``brainpy.math.Variable`` (like ``self.dyn``),\n",
    "   - or a function argument (like ``inp``)."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "ExecuteTime": {
     "end_time": "2025-10-06T02:57:13.390501Z",
     "start_time": "2025-10-06T02:57:13.358416Z"
    }
   },
   "source": [
    "example.update(1.)\n",
    "example.dyn"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Variable(\n",
       "  value=ShapedArray(float32[1]),\n",
       "  _batch_axis=None,\n",
       "  axis_names=None\n",
       ")"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 5
  },
  {
   "cell_type": "code",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "ExecuteTime": {
     "end_time": "2025-10-06T02:57:13.402769Z",
     "start_time": "2025-10-06T02:57:13.397834Z"
    }
   },
   "source": [
    "example.update(2.)\n",
    "example.dyn"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Variable(\n",
       "  value=ShapedArray(float32[1]),\n",
       "  _batch_axis=None,\n",
       "  axis_names=None\n",
       ")"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 6
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "source": [
    "2. All other variables are compiled as **constants** during OO transformations. Changes made to values that are neither a ``Variable`` nor a function argument have no effect once the function has been compiled."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "ExecuteTime": {
     "end_time": "2025-10-06T02:57:13.428554Z",
     "start_time": "2025-10-06T02:57:13.423394Z"
    }
   },
   "source": [
    "example.static = 100.  # has no effect: `static` was already compiled as a constant\n",
    "example.update(1.)\n",
    "example.dyn"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Variable(\n",
       "  value=ShapedArray(float32[1]),\n",
       "  _batch_axis=None,\n",
       "  axis_names=None\n",
       ")"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 7
  },
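  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The same constant behavior can be reproduced in plain JAX, which may help explain why it happens. The following is a minimal sketch, not BrainPy code: the class `Scale` and its attribute `factor` are hypothetical names chosen for illustration. A plain Python attribute read inside a jitted function is baked in when the function is first traced, so a later change is not seen by the cached compiled function.\n",
    "\n",
    "```python\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "\n",
    "class Scale:\n",
    "    def __init__(self):\n",
    "        self.factor = 2.0                  # plain Python attribute, not a traced value\n",
    "        self.apply = jax.jit(self._apply)  # compile once and reuse the cached function\n",
    "\n",
    "    def _apply(self, x):\n",
    "        # `self.factor` is read while tracing and stored as a constant\n",
    "        return x * self.factor\n",
    "\n",
    "\n",
    "s = Scale()\n",
    "first = s.apply(jnp.ones(1))    # traced and compiled with factor = 2.0\n",
    "s.factor = 100.0                # changes the attribute, not the compiled code\n",
    "second = s.apply(jnp.ones(1))   # same input shape and dtype: the cached version is reused\n",
    "print(first, second)\n",
    "```\n",
    "\n",
    "Both calls use the factor that was present at trace time. In this sketch, passing the value as a function argument is the way to make it dynamic again, which mirrors the two options listed in point 1."
   ]
  },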
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "source": [
    "3. All OO transformations provided in BrainPy are listed in our [API documentation](../apis/auto/math.rst). In short, these OO transformations include:\n",
    "\n",
    "   - automatic differentiation transformations\n",
    "   - just-in-time compilation\n",
    "   - control flow transformations\n",
    "   - ..."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## A complex example: Training a network"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With this basic understanding of how OO transformations work, we can train a neural network model using these transformations.\n",
    "\n",
    "In this training case, we want to teach the neural network to classify random arrays into two labels (`True` or `False`). That is, we have the training data:"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-06T02:57:13.916909Z",
     "start_time": "2025-10-06T02:57:13.447658Z"
    }
   },
   "source": [
    "num_in = 100\n",
    "num_sample = 256\n",
    "X = bm.random.rand(num_sample, num_in)\n",
    "Y = (bm.random.rand(num_sample) < 0.5).astype(float)"
   ],
   "outputs": [],
   "execution_count": 8
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We use a two-layer feedforward network:"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-06T02:57:14.680044Z",
     "start_time": "2025-10-06T02:57:13.922137Z"
    }
   },
   "source": [
    "class Linear(bp.BrainPyObject):\n",
    "    def __init__(self, n_in, n_out):\n",
    "        super().__init__()\n",
    "        self.num_in = n_in\n",
    "        self.num_out = n_out\n",
    "        init = bp.init.XavierNormal()\n",
    "        self.W = bm.Variable(init((n_in, n_out)))\n",
    "        self.b = bm.Variable(bm.zeros((1, n_out)))\n",
    "\n",
    "    def __call__(self, x):\n",
    "        return x @ self.W + self.b\n",
    "\n",
    "\n",
    "net = bp.Sequential(Linear(num_in, 20),\n",
    "                    bm.relu,\n",
    "                    Linear(20, 2))\n",
    "print(net)"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sequential(\n",
      "  [0] <__main__.Linear object at 0x000001EB702A9E50>\n",
      "  [1] relu\n",
      "  [2] <__main__.Linear object at 0x000001EB702A8BF0>\n",
      ")\n"
     ]
    }
   ],
   "execution_count": 9
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here, we use a supervised learning training paradigm."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "ExecuteTime": {
     "end_time": "2025-10-06T02:57:14.697299Z",
     "start_time": "2025-10-06T02:57:14.686122Z"
    }
   },
   "source": [
    "class Trainer(object):\n",
    "    def __init__(self, net):\n",
    "        self.net = net\n",
    "        self.grad = bm.grad(self.loss, grad_vars=net.vars(), return_value=True)\n",
    "        self.optimizer = bp.optim.SGD(lr=1e-2, train_vars=net.vars())\n",
    "\n",
    "    @bm.cls_jit(inline=True)\n",
    "    def loss(self):\n",
    "        # shuffle the data\n",
    "        key = bm.random.split_key()\n",
    "        x_data = bm.random.permutation(X, key=key)\n",
    "        y_data = bm.random.permutation(Y, key=key)\n",
    "        # prediction (use the network stored on the trainer, not the global `net`)\n",
    "        predictions = self.net(dict(), x_data)\n",
    "        # loss\n",
    "        l = bp.losses.cross_entropy_loss(predictions, y_data)\n",
    "        return l\n",
    "\n",
    "    @bm.cls_jit\n",
    "    def train(self):\n",
    "        grads, l = self.grad()\n",
    "        self.optimizer.update(grads)\n",
    "        return l"
   ],
   "outputs": [],
   "execution_count": 10
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-06T02:57:15.872108Z",
     "start_time": "2025-10-06T02:57:14.721421Z"
    }
   },
   "source": [
    "trainer = Trainer(net)\n",
    "\n",
    "for i in range(1, 4001):\n",
    "    ls = trainer.train()\n",
    "    if i % 400 == 0:\n",
    "        print(f'Train {i} epoch, loss = {ls:.4f}')"
   ],
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "D:\\codes\\projects\\BrainPy\\brainpy\\version2\\dynsys.py:341: UserWarning: \n",
      "From brainpy>=2.4.3, update() function no longer needs to receive a global shared argument.\n",
      "\n",
      "Instead of using:\n",
      "\n",
      "  def update(self, tdi, *args, **kwagrs):\n",
      "     t = tdi['t']\n",
      "     ...\n",
      "\n",
      "Please use:\n",
      "\n",
      "  def update(self, *args, **kwagrs):\n",
      "     t = bp.share['t']\n",
      "     ...\n",
      "\n",
      "  warnings.warn(_update_deprecate_msg, UserWarning)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train 400 epoch, loss = 0.6338\n",
      "Train 800 epoch, loss = 0.5872\n",
      "Train 1200 epoch, loss = 0.5400\n",
      "Train 1600 epoch, loss = 0.4924\n",
      "Train 2000 epoch, loss = 0.4477\n",
      "Train 2400 epoch, loss = 0.4063\n",
      "Train 2800 epoch, loss = 0.3679\n",
      "Train 3200 epoch, loss = 0.3329\n",
      "Train 3600 epoch, loss = 0.2994\n",
      "Train 4000 epoch, loss = 0.2679\n"
     ]
    }
   ],
   "execution_count": 11
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The above example contains the classical elements of neural network training:\n",
    "\n",
    "- `net`: neural network\n",
    "- `loss`: loss function\n",
    "- `grad`: gradient function\n",
    "- `optimizer`: parameter optimizer\n",
    "- `train`: training step"
   ]
  },
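  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For contrast, the same roles (loss function, gradient function, parameter update, training step) can be written in plain functional JAX, where the parameters are passed in and returned explicitly instead of being stored as ``Variable`` attributes. This is only a sketch under stated assumptions: `loss_fn`, `train_step`, the single linear layer, the squared-error loss, and the toy regression data are all illustrative and do not come from the example above.\n",
    "\n",
    "```python\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "\n",
    "def loss_fn(params, x, y):\n",
    "    # one linear layer followed by a mean squared error\n",
    "    pred = x @ params['W'] + params['b']\n",
    "    return jnp.mean((pred - y) ** 2)\n",
    "\n",
    "\n",
    "@jax.jit\n",
    "def train_step(params, x, y):\n",
    "    # loss value and gradients with respect to the first argument (params)\n",
    "    l, grads = jax.value_and_grad(loss_fn)(params, x, y)\n",
    "    # plain SGD update with a fixed learning rate of 0.1\n",
    "    new_params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)\n",
    "    return new_params, l\n",
    "\n",
    "\n",
    "key = jax.random.PRNGKey(0)\n",
    "x = jax.random.uniform(key, (32, 4))\n",
    "y = jnp.sum(x, axis=1, keepdims=True)   # toy target: the sum of the inputs\n",
    "params = {'W': jnp.zeros((4, 1)), 'b': jnp.zeros((1,))}\n",
    "\n",
    "losses = []\n",
    "for _ in range(100):\n",
    "    # the updated parameters must be threaded through every call by hand\n",
    "    params, l = train_step(params, x, y)\n",
    "    losses.append(float(l))\n",
    "print(losses[0], losses[-1])\n",
    "```\n",
    "\n",
    "In the functional version the caller carries the state from step to step. The OO transformations remove that bookkeeping by letting each ``Variable`` hold its own state."
   ]
  },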
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## ``Variable`` and ``BrainPyObject``"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Although OO transformations in BrainPy do not explicitly require ``BrainPyObject``, defining a class as a subclass of ``BrainPyObject`` brings many advantages.\n",
    "\n",
    "``BrainPyObject`` can be viewed as a container that holds all the [Variable](../tutorial_math/variables.ipynb) instances needed for our computation."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![](./imgs/net_with_two_linear.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In the above example, each ``Linear`` object has two ``Variable`` instances: *W* and *b*. The ``net`` we defined is composed of two ``Linear`` objects, so we can expect to retrieve four variables from it."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-06T02:57:17.925977Z",
     "start_time": "2025-10-06T02:57:17.922443Z"
    }
   },
   "source": [
    "net.vars().keys()"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "dict_keys(['Linear0.W', 'Linear0.b', 'Linear1.W', 'Linear1.b'])"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 12
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "An important question is, **how to define `Variable` in a `BrainPyObject` so that we can retrieve all of them?**"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Every `Variable` instance that is reachable as a `self.` attribute can be retrieved from a `BrainPyObject` recursively. \n",
    "No matter how deeply ``BrainPyObject`` instances are nested, as long as each `BrainPyObject` instance and its `Variable` instances are reachable through `self.` attribute access, all of them will be retrieved. "
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-06T02:57:18.190524Z",
     "start_time": "2025-10-06T02:57:17.944733Z"
    }
   },
   "source": [
    "class SuperLinear(bp.BrainPyObject):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.l1 = Linear(10, 20)\n",
    "        self.v1 = bm.Variable(3)\n",
    "\n",
    "sl = SuperLinear()"
   ],
   "outputs": [],
   "execution_count": 13
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-06T02:57:18.202830Z",
     "start_time": "2025-10-06T02:57:18.197584Z"
    }
   },
   "source": [
    "# retrieve Variable\n",
    "sl.vars().keys()"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "dict_keys(['SuperLinear0.v1', 'Linear2.W', 'Linear2.b'])"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 14
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-06T02:57:18.224917Z",
     "start_time": "2025-10-06T02:57:18.217929Z"
    }
   },
   "source": [
    "# retrieve BrainPyObject\n",
    "sl.nodes().keys()"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "dict_keys(['SuperLinear0', 'Linear2'])"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 15
  },
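  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The recursive retrieval shown above can be pictured with a small pure-Python sketch. This is not BrainPy's implementation: `ToyVariable`, `ToyObject`, and `collect_vars` are hypothetical stand-ins, and the key names are built from attribute paths rather than from BrainPy's object names. The sketch only illustrates the idea of walking `self.` attributes and recursing into child objects.\n",
    "\n",
    "```python\n",
    "class ToyVariable:\n",
    "    '''Stand-in for a value that should be tracked.'''\n",
    "\n",
    "    def __init__(self, value):\n",
    "        self.value = value\n",
    "\n",
    "\n",
    "class ToyObject:\n",
    "    '''Stand-in for an object that holds variables and child objects.'''\n",
    "\n",
    "    def collect_vars(self, prefix=None):\n",
    "        # Walk the instance attributes: record variables directly\n",
    "        # and recurse into child objects.\n",
    "        prefix = prefix or type(self).__name__\n",
    "        found = {}\n",
    "        for name, attr in vars(self).items():\n",
    "            if isinstance(attr, ToyVariable):\n",
    "                found[prefix + '.' + name] = attr\n",
    "            elif isinstance(attr, ToyObject):\n",
    "                found.update(attr.collect_vars(prefix + '.' + name))\n",
    "        return found\n",
    "\n",
    "\n",
    "class Leaf(ToyObject):\n",
    "    def __init__(self):\n",
    "        self.W = ToyVariable(1.0)\n",
    "        self.b = ToyVariable(0.0)\n",
    "\n",
    "\n",
    "class Parent(ToyObject):\n",
    "    def __init__(self):\n",
    "        self.child = Leaf()\n",
    "        self.v1 = ToyVariable(3)\n",
    "        self.hidden = [ToyVariable(5)]  # inside a plain list: never visited\n",
    "\n",
    "\n",
    "names = sorted(Parent().collect_vars())\n",
    "print(names)  # ['Parent.child.W', 'Parent.child.b', 'Parent.v1']\n",
    "```\n",
    "\n",
    "The variable stored in the plain list is missing from the result because this walk only inspects direct attributes that are a `ToyVariable` or a `ToyObject`."
   ]
  },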
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "However, a ``BrainPyObject`` or ``Variable`` stored in a plain Python container (like a tuple, list, or dict) cannot be retrieved this way. For this case, we can use ``brainpy.math.NodeList`` or ``brainpy.math.VarList``:"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-06T02:57:18.242313Z",
     "start_time": "2025-10-06T02:57:18.237696Z"
    }
   },
   "source": [
    "class SuperSuperLinear(bp.BrainPyObject):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.ss = bm.NodeList([SuperLinear(), SuperLinear()])\n",
    "        self.vv = bm.VarList([bm.Variable(3)])"
   ],
   "outputs": [],
   "execution_count": 16
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-06T02:57:18.288920Z",
     "start_time": "2025-10-06T02:57:18.256031Z"
    }
   },
   "source": [
    "ssl = SuperSuperLinear()\n",
    "print(ssl.vars().keys())\n",
    "print(ssl.nodes().keys())"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "dict_keys(['SuperSuperLinear0.vv-0', 'SuperLinear1.v1', 'SuperLinear2.v1', 'Linear3.W', 'Linear3.b', 'Linear4.W', 'Linear4.b'])\n",
      "dict_keys(['SuperSuperLinear0', 'SuperLinear1', 'SuperLinear2', 'Linear3', 'Linear4'])\n"
     ]
    }
   ],
   "execution_count": 17
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6"
  },
  "latex_envs": {
   "LaTeX_envs_menu_present": true,
   "autoclose": false,
   "autocomplete": true,
   "bibliofile": "biblio.bib",
   "cite_by": "apalike",
   "current_citInitial": 1,
   "eqLabelWithNumbers": true,
   "eqNumInitial": 1,
   "hotkeys": {
    "equation": "Ctrl-E",
    "itemize": "Ctrl-I"
   },
   "labels_anchors": false,
   "latex_user_defs": false,
   "report_style_numbering": false,
   "user_envs_cfg": false
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": false,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {
    "height": "calc(100% - 180px)",
    "left": "10px",
    "top": "150px",
    "width": "245.75px"
   },
   "toc_section_display": true,
   "toc_window_display": true
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
